"""
Script for comparing OPT model performance between AMDSHARK and Huggingface
PyTorch.

Usage Example:

python opt_perf_comparison.py --max-seq-len=32 --model-name=facebook/opt-125m \
        --platform=amdshark

python opt_perf_comparison.py --max-seq-len=512 --model-name=facebook/opt-1.3b \
        --platform=amdshark

See parse_args() below for command line argument usage.
"""

import argparse
import collections
import json
import os
import psutil
import time
import numpy as np
from typing import Tuple

from opt_util import PROMPTS
from amdshark.amdshark_inference import AMDSharkInference
from amdshark.amdshark_importer import import_with_fx
from transformers import AutoTokenizer, OPTForCausalLM
from amdshark_opt_wrapper import OPTForCausalLMModel
from amdshark.parser import amdshark_args
import iree.compiler as ireec

# Target device used when compiling and naming the vmfb artifacts.
DEVICE = "cpu"

# Values accepted by --platform.
PLATFORM_AMDSHARK = "amdshark"
PLATFORM_HUGGINGFACE = "huggingface"

# Dict keys for reports.
REPORT_PLATFORM = "platform"
REPORT_MODEL_NAME = "model"
REPORT_MAX_SEQ_LEN = "max_seq_len"
REPORT_LOAD_TIME = "load_time_sec"
REPORT_RUN_TIME = "run_time_sec"
# Memory keys: values are psutil byte counts shifted right by 20 (i.e. MiB).
REPORT_LOAD_PHYSICAL_MEMORY_MB = "load_physical_MB"
REPORT_LOAD_VIRTUAL_MEMORY_MB = "load_virtual_MB"
REPORT_RUN_PHYSICAL_MEMORY_MB = "run_physical_MB"
REPORT_RUN_VIRTUAL_MEMORY_MB = "run_virtual_MB"

# Pairs a runnable model (AMDShark module or HF model) with its tokenizer.
ModelWrapper = collections.namedtuple("ModelWrapper", "model tokenizer")


def get_memory_info():
    """Return psutil's ``memory_info()`` snapshot for this process.

    Callers read the ``rss`` and ``vms`` fields (bytes) from the result.
    """
    return psutil.Process(os.getpid()).memory_info()


def import_mlir_module(
    model_name: str,
    tokenizer,
    device: str,
    max_seq_len: int,
):
    """Import an OPT causal-LM to Torch MLIR via FX and save it to disk.

    The MLIR text is written to
    ``./<fs_name>_causallm_<max_seq_len>_torch.mlir`` where ``fs_name`` comes
    from ``get_opt_fs_name(model_name)``.

    Args:
        model_name: Hugging Face model ID to load with ``OPTForCausalLM``.
        tokenizer: Tokenizer used to build the example inputs for tracing.
        device: Unused here; kept so existing callers keep working.
        max_seq_len: Length the example prompt is padded/truncated to.

    Returns:
        The path the MLIR file was written to.
    """
    opt_base_model = OPTForCausalLM.from_pretrained(
        model_name, ignore_mismatched_sizes=True
    )
    opt_base_model.eval()
    opt_model = OPTForCausalLMModel(opt_base_model)
    # Example inputs for tracing: one prompt padded/truncated to exactly
    # max_seq_len tokens (presumably this fixes the compiled input shape).
    encoded_inputs = tokenizer(
        PROMPTS[0],
        padding="max_length",
        truncation=True,
        max_length=max_seq_len,
        return_tensors="pt",
    )
    inputs = (
        encoded_inputs["input_ids"],
        encoded_inputs["attention_mask"],
    )

    opt_fs_name = get_opt_fs_name(model_name)
    mlir_path = f"./{opt_fs_name}_causallm_{max_seq_len}_torch.mlir"
    model_mlir, _ = import_with_fx(
        model=opt_model,
        inputs=inputs,
        is_f16=False,
        model_name=opt_fs_name,
        return_str=True,
    )
    # Explicit encoding so the output does not depend on the platform locale.
    with open(mlir_path, "w", encoding="utf-8") as f:
        f.write(model_mlir)
    print(f"Saved mlir at {mlir_path}")
    return mlir_path


def create_vmfb_module(
    model_name: str,
    tokenizer,
    device: str,
    max_seq_len: int,
    recompile_amdshark: bool,
) -> str:
    """Compile the OPT model's Torch MLIR into a vmfb file and return its path.

    The MLIR is read from ``./<fs_name>_causallm_<max_seq_len>_torch.mlir``.
    If that file does not exist, it is first produced by
    ``import_mlir_module``.

    Args:
        model_name: Hugging Face model ID, e.g. ``facebook/opt-1.3b``.
        tokenizer: Tokenizer forwarded to ``import_mlir_module`` if needed.
        device: Device passed to ``AMDSharkInference`` for compilation.
        max_seq_len: Sequence length the module is built for.
        recompile_amdshark: When True, the MLIR file must already exist, so
            that only compilation (not MLIR import) is measured.

    Returns:
        Path of the saved ``.vmfb`` file.

    Raises:
        FileNotFoundError: if ``recompile_amdshark`` is set but the MLIR file
            does not exist yet.
    """
    opt_fs_name = get_opt_fs_name(model_name)
    mlir_path = f"./{opt_fs_name}_causallm_{max_seq_len}_torch.mlir"
    # If MLIR has already been imported, reuse the file instead of re-importing.
    has_mlir = os.path.isfile(mlir_path)
    # The purpose of recompile_amdshark is to measure compilation time, which
    # is only meaningful once the MLIR already exists. Validate explicitly:
    # an `assert` would be stripped under `python -O`.
    if recompile_amdshark and not has_mlir:
        raise FileNotFoundError(
            f"--recompile-amdshark requires an existing MLIR file at "
            f"{mlir_path}; run once without --recompile-amdshark to create it."
        )
    if not has_mlir:
        import_mlir_module(
            model_name,
            tokenizer,
            device,
            max_seq_len,
        )
    amdshark_module = AMDSharkInference(
        mlir_path,
        device=device,
        mlir_dialect="tm_tensor",
        is_benchmark=False,
        rt_flags=[],
    )

    # The name embeds the module-level DEVICE (not `device`) so that it
    # matches the file name load_amdshark_model looks for.
    vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}"
    # Presumably save_module compiles and writes `<vmfb_name>.vmfb`.
    amdshark_module.save_module(module_name=vmfb_name)
    return vmfb_name + ".vmfb"


def load_amdshark_model(
    model_name: str,
    token_model_name: str,
    max_seq_len: int,
    recompile_amdshark: bool,
    plugin_path: "str | None" = None,
) -> ModelWrapper:
    """Load the compiled AMDShark OPT module and its tokenizer.

    The module is compiled first when the vmfb file is missing or when
    ``recompile_amdshark`` is set.

    Args:
        model_name: Hugging Face model ID whose vmfb is loaded.
        token_model_name: Hugging Face ID used to build the (slow) tokenizer.
        max_seq_len: Sequence length the module is compiled for.
        recompile_amdshark: Force compilation even if the vmfb exists.
        plugin_path: Optional executable-plugin path, passed to the runtime
            as ``--executable_plugin=<path>``. None or empty means no plugin.

    Returns:
        ModelWrapper holding the loaded AMDSharkInference and the tokenizer.
    """
    opt_fs_name = get_opt_fs_name(model_name)
    vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}.vmfb"
    tokenizer = AutoTokenizer.from_pretrained(token_model_name, use_fast=False)
    if recompile_amdshark or not os.path.isfile(vmfb_name):
        if recompile_amdshark:
            print(f"Recompiling and saving to {vmfb_name}")
        else:
            print(f"vmfb not found. compiling and saving to {vmfb_name}")
        create_vmfb_module(
            model_name, tokenizer, DEVICE, max_seq_len, recompile_amdshark
        )
    # The old default of `[]` slipped past an `is not None` check and produced
    # the bogus flag "--executable_plugin=[]". A truthiness test treats None
    # and empty values alike.
    rt_flags = [f"--executable_plugin={plugin_path}"] if plugin_path else []
    amdshark_module = AMDSharkInference(
        mlir_module=None, device="cpu-task", rt_flags=rt_flags
    )
    amdshark_module.load_module(vmfb_name)
    return ModelWrapper(model=amdshark_module, tokenizer=tokenizer)


def run_amdshark_model(model_wrapper: ModelWrapper, tokens):
    """Invoke the compiled module's ``forward`` function on ``tokens``.

    ``tokens`` is the ``(input_ids, attention_mask)`` tuple built by the
    caller. The module's output is returned as-is; callers treat it as logits.
    """
    compiled_module = model_wrapper.model
    return compiled_module("forward", tokens)


def load_huggingface_model(
    model_name: str, token_model_name: str
) -> ModelWrapper:
    """Load the PyTorch OPT model and its tokenizer via ``from_pretrained``.

    Args:
        model_name: ID/path passed to ``OPTForCausalLM.from_pretrained``.
        token_model_name: ID/path passed to ``AutoTokenizer.from_pretrained``.
    """
    hf_model = OPTForCausalLM.from_pretrained(model_name)
    hf_tokenizer = AutoTokenizer.from_pretrained(token_model_name)
    return ModelWrapper(model=hf_model, tokenizer=hf_tokenizer)


def run_huggingface_model(model_wrapper: ModelWrapper, tokens):
    """Run one inference forward pass of the Hugging Face OPT model.

    Args:
        model_wrapper: Wrapper whose ``model`` is an ``OPTForCausalLM``.
        tokens: Tokenizer output exposing ``input_ids`` and ``attention_mask``.

    Returns:
        The model's tuple output (``return_dict=False``); callers use
        element 0 as the logits.
    """
    import torch  # local import: the rest of this file does not import torch

    # Inference only. With autograd enabled, a backward graph is built, which
    # inflates the measured run time and memory of the Hugging Face baseline.
    with torch.no_grad():
        return model_wrapper.model(
            tokens.input_ids, tokens.attention_mask, return_dict=False
        )


def save_json(data, filename):
    """Write ``data`` as JSON to ``filename``, replacing any existing file."""
    with open(filename, mode="w") as out_file:
        json.dump(obj=data, fp=out_file)


def collect_huggingface_logits(
    model_name: str,
    token_model_name: str,
    max_seq_len: int,
    to_save_json: bool,
) -> dict:
    """Benchmark the Hugging Face PyTorch OPT model on every prompt in PROMPTS.

    Args:
        model_name: Hugging Face model ID to load.
        token_model_name: Hugging Face ID used to build the tokenizer.
        max_seq_len: Length every prompt is padded/truncated to.
        to_save_json: If True, also dump per-prompt logits to
            ``/tmp/huggingface.json``.

    Returns:
        A report dict keyed by the ``REPORT_*`` constants: load time, average
        per-prompt run time, and RSS/VMS memory (bytes >> 20, i.e. MiB) after
        loading and after running.
    """
    # Load
    t0 = time.perf_counter()
    model_wrapper = load_huggingface_model(model_name, token_model_name)
    load_time = time.perf_counter() - t0
    print("--- Took {} seconds to load Huggingface.".format(load_time))
    load_memory_info = get_memory_info()

    # Tokenize everything up front so tokenization is not part of run time.
    tokenized_prompts = [
        model_wrapper.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_seq_len,
            truncation=True,
            return_tensors="pt",
        )
        for prompt in PROMPTS
    ]

    # Run (perf_counter is monotonic and high-resolution, unlike time.time).
    results = []
    t0 = time.perf_counter()
    for prompt, tokens in zip(PROMPTS, tokenized_prompts):
        print("prompt: {}".format(prompt))
        logits = run_huggingface_model(model_wrapper, tokens)
        if to_save_json:
            results.append([prompt, logits[0].tolist()])
    run_time = time.perf_counter() - t0
    print("--- Took {} seconds to run Huggingface.".format(run_time))
    if to_save_json:
        save_json(results, "/tmp/huggingface.json")
    run_memory_info = get_memory_info()
    return {
        REPORT_PLATFORM: PLATFORM_HUGGINGFACE,
        REPORT_MODEL_NAME: model_name,
        REPORT_MAX_SEQ_LEN: max_seq_len,
        REPORT_LOAD_TIME: load_time,
        REPORT_RUN_TIME: run_time / len(PROMPTS),
        REPORT_LOAD_PHYSICAL_MEMORY_MB: load_memory_info.rss >> 20,
        REPORT_LOAD_VIRTUAL_MEMORY_MB: load_memory_info.vms >> 20,
        REPORT_RUN_PHYSICAL_MEMORY_MB: run_memory_info.rss >> 20,
        REPORT_RUN_VIRTUAL_MEMORY_MB: run_memory_info.vms >> 20,
    }


def collect_amdshark_logits(
    model_name: str,
    token_model_name: str,
    max_seq_len: int,
    recompile_amdshark: bool,
    to_save_json: bool,
    plugin_path: "str | None",
) -> dict:
    """Benchmark the AMDShark-compiled OPT model on every prompt in PROMPTS.

    Args:
        model_name: Hugging Face model ID whose compiled module is used.
        token_model_name: Hugging Face ID used to build the tokenizer.
        max_seq_len: Length every prompt is padded/truncated to.
        recompile_amdshark: Force recompilation of the vmfb before loading.
        to_save_json: If True, also dump per-prompt outputs to
            ``/tmp/amdshark.json``.
        plugin_path: Optional executable-plugin path for the runtime.

    Returns:
        A report dict keyed by the ``REPORT_*`` constants. The platform is
        suffixed with ``-compile`` or ``-precompiled``. Memory values are
        bytes >> 20 (MiB).
    """
    # Load
    t0 = time.perf_counter()
    model_wrapper = load_amdshark_model(
        model_name, token_model_name, max_seq_len, recompile_amdshark, plugin_path
    )
    load_time = time.perf_counter() - t0
    print("--- Took {} seconds to load AMDShark.".format(load_time))
    load_memory_info = get_memory_info()

    # Tokenize everything up front so tokenization is not part of run time.
    tokenized_prompts = []
    for prompt in PROMPTS:
        tokens = model_wrapper.tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=max_seq_len,
            return_tensors="pt",
        )
        tokenized_prompts.append((tokens["input_ids"], tokens["attention_mask"]))

    # Run (perf_counter is monotonic and high-resolution, unlike time.time).
    results = []
    t0 = time.perf_counter()
    for prompt, inputs in zip(PROMPTS, tokenized_prompts):
        print("prompt: {}".format(prompt))
        logits = run_amdshark_model(model_wrapper, inputs)
        if to_save_json:
            # Converting outputs to Python lists is expensive for large
            # logits. Only do it when saving, as the Hugging Face path does;
            # otherwise it biases the measured run time.
            results.append([prompt, [e.tolist() for e in logits]])
    run_time = time.perf_counter() - t0
    print("--- Took {} seconds to run AMDShark.".format(run_time))
    if to_save_json:
        save_json(results, "/tmp/amdshark.json")
    platform_postfix = "-compile" if recompile_amdshark else "-precompiled"
    run_memory_info = get_memory_info()
    return {
        REPORT_PLATFORM: PLATFORM_AMDSHARK + platform_postfix,
        REPORT_MODEL_NAME: model_name,
        REPORT_MAX_SEQ_LEN: max_seq_len,
        REPORT_LOAD_TIME: load_time,
        REPORT_RUN_TIME: run_time / len(PROMPTS),
        REPORT_LOAD_PHYSICAL_MEMORY_MB: load_memory_info.rss >> 20,
        REPORT_LOAD_VIRTUAL_MEMORY_MB: load_memory_info.vms >> 20,
        REPORT_RUN_PHYSICAL_MEMORY_MB: run_memory_info.rss >> 20,
        REPORT_RUN_VIRTUAL_MEMORY_MB: run_memory_info.vms >> 20,
    }


def get_opt_fs_name(model_name: str) -> str:
    """Convert an OPT model ID into a file-system-friendly name.

    Any single ``org/`` prefix is dropped. Then every ``-`` becomes ``_`` and
    every ``.`` becomes ``-``.

    Example: get_opt_fs_name('facebook/opt-1.3b') == 'opt_1-3b'

    Raises:
        AssertionError: if ``model_name`` contains more than one slash.
    """
    path_parts = model_name.split("/")
    assert len(path_parts) in (1, 2), "There should be at most one slash."
    base_name = path_parts[-1]
    # One-pass mapping. This matches replacing "-" first and "." second,
    # because the "-" produced from "." is never fed back into the "-" rule.
    return base_name.translate(str.maketrans({"-": "_", ".": "-"}))


def parse_args():
    """Parse this script's command-line options and echo them to stdout.

    Returns:
        argparse.Namespace with ``save_json``, ``max_seq_len``,
        ``model_name``, ``recompile_amdshark``, ``platform``, ``plugin_path``
        and ``token_model_name``.
    """
    supported_models = [
        "facebook/opt-125m",
        "facebook/opt-350m",
        "facebook/opt-1.3b",
        "facebook/opt-6.7b",
        "mit-han-lab/opt-125m-smoothquant",
        "mit-han-lab/opt-1.3b-smoothquant",
        "mit-han-lab/opt-2.7b-smoothquant",
        "mit-han-lab/opt-6.7b-smoothquant",
        "mit-han-lab/opt-13b-smoothquant",
    ]
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add(
        "--save-json",
        help="If set, saves output JSON.",
        action=argparse.BooleanOptionalAction,
        default=False,
    )
    add("--max-seq-len", help="Max sequence length", type=int, default=32)
    add(
        "--model-name",
        help="Model name",
        type=str,
        choices=supported_models,
        default="facebook/opt-1.3b",
    )
    add(
        "--recompile-amdshark",
        help="If set, recompiles MLIR",
        action=argparse.BooleanOptionalAction,
        default=False,
    )
    add(
        "--platform",
        help="Either amdshark or huggingface",
        type=str,
        choices=[PLATFORM_AMDSHARK, PLATFORM_HUGGINGFACE],
        default=PLATFORM_AMDSHARK,
    )
    add(
        "--plugin-path",
        help="path to executable plugin",
        type=str,
        default=None,
    )
    add(
        "--token-model-name",
        help="HF ID to create tokenizer.",
        type=str,
        default=None,
    )
    parsed = parser.parse_args()
    print(f"args={parsed}")
    return parsed


if __name__ == "__main__":
    args = parse_args()
    if args.token_model_name is None:
        if "smoothquant" in args.model_name:
            # SmoothQuant checkpoints use the tokenizer of their base model:
            # "mit-han-lab/opt-1.3b-smoothquant" -> "facebook/opt-1.3b".
            base_name = args.model_name.split("/")[-1].removesuffix(
                "-smoothquant"
            )
            args.token_model_name = f"facebook/{base_name}"
        else:
            args.token_model_name = args.model_name
    if args.platform == PLATFORM_AMDSHARK:
        report = collect_amdshark_logits(
            args.model_name,
            args.token_model_name,
            args.max_seq_len,
            args.recompile_amdshark,
            args.save_json,
            args.plugin_path,
        )
    else:
        report = collect_huggingface_logits(
            args.model_name,
            args.token_model_name,
            args.max_seq_len,
            args.save_json,
        )
    print("# Summary: {}".format(json.dumps(report)))
